
[PyTorch] Fix FlashAttention 2 head_dim > 192 on sm103 and other architectures #2836

Open

pedramr wants to merge 4 commits into NVIDIA:main from pedramr:fix/sm103-flash-attn-allowlist

Conversation

pedramr commented Apr 4, 2026

Description

The head_dim > 192 gate for FlashAttention 2 in get_attention_backend used an exact-match
compute capability allowlist: (8,0), (9,0), (10,0), (12,0). This excluded sm103 (B300/GB300),
sm89 (L40S/RTX 4090), sm86 (A40/RTX 3090), and other valid architectures where flash-attn
supports head_dim up to 256.

This PR replaces the allowlist with a >= sm80 range check, matching flash-attn's own gate:
Dao-AILab/flash-attention@bbb21d6

The sm103 case was validated on hardware with head_dim=256; the remaining architectures appear
to be supported based on flash-attn's >= sm80 guarantee.
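
In pseudocode, the change looks roughly like the sketch below. This is a minimal paraphrase of the gate in get_attention_backend, not the literal source; the variable names (device_compute_capability, head_dim_qk, use_flash_attention_2) follow the PR summary, and the sample values are illustrative only:

```python
# Sketch of the FA2 head_dim gate, before and after (hypothetical simplification).
device_compute_capability = (10, 3)  # e.g. sm103 (B300/GB300)
head_dim_qk = 256
use_flash_attention_2 = True

# Before: exact-match allowlist; sm103, sm89, sm86, etc. fall through and lose FA2.
if head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)):
    use_flash_attention_2 = False  # wrongly disabled on sm103

# After: range check matching flash-attn's own >= sm80 gate.
use_flash_attention_2 = True
if head_dim_qk > 192 and device_compute_capability < (8, 0):
    use_flash_attention_2 = False  # stays enabled on sm103

print(use_flash_attention_2)  # True
```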

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Changes

  • Replace exact-match compute capability allowlist with device_compute_capability < (8, 0) range check
  • Update debug log message from sm80/90/100+ to sm80+

[PyTorch] Fix FlashAttention 2 head_dim > 192 on sm103 and other architectures

Replace the exact-match compute capability allowlist with a >= sm80 range
check, matching flash-attn's own gate:
Dao-AILab/flash-attention@bbb21d6

The allowlist ((8,0), (9,0), (10,0), (12,0)) missed sm103 (B300), sm89
(L40S), sm86 (A40), and others where FA2 supports head_dim up to 256.
The sm103 case was validated on hardware with head_dim=256; the remaining
architectures appear to be supported based on flash-attn's >= sm80 guarantee.

Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com>
greptile-apps bot (Contributor) commented Apr 4, 2026

Greptile Summary

This PR fixes a bug in get_attention_backend where FlashAttention 2 was incorrectly disabled for head_dim > 192 on architectures not in an exact-match allowlist (sm80, sm90, sm100, sm120), excluding valid devices like sm103, sm89, and sm86. The fix replaces the allowlist with the simpler head_dim_qk > 256 or head_dim_qk % 8 != 0 condition, correctly aligning with flash-attn's own >= sm80 support guarantee — which is already enforced earlier in the function at the compute-capability filter (lines 448–451).

Confidence Score: 5/5

Safe to merge — minimal, targeted bug fix with correct logic and no regressions introduced.

The change removes a single, clearly erroneous allowlist condition. The new condition is logically equivalent to flash-attn's own gate given the earlier < sm80 guard already disables FA2 before this point. The dead-code concern flagged in the previous review thread is fully resolved by removing the branch entirely. No new logic is added, and the log message is updated consistently.
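
Concretely, the equivalence argument can be sketched as follows. This is a paraphrase of the summary and flowchart above, not the literal source; the line numbers and variable names come from the summary, and the sample values are illustrative:

```python
# Sketch of why the architecture check is dead code. Per the summary, FA2 is
# already disabled on pre-sm80 devices earlier in get_attention_backend
# (around lines 448-451):
device_compute_capability = (8, 9)  # e.g. sm89 (L40S)
head_dim_qk = 256
use_flash_attention_2 = True

if device_compute_capability < (8, 0):
    use_flash_attention_2 = False

# Any device that still has FA2 enabled here satisfies cc >= (8, 0), so an
# extra architecture check in the head_dim branch is redundant. The gate
# reduces to flash-attn's documented limits on sm80+:
if use_flash_attention_2 and (head_dim_qk > 256 or head_dim_qk % 8 != 0):
    use_flash_attention_2 = False

print(use_flash_attention_2)  # True: head_dim 256 is allowed on sm89
```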

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Removes the exact-match compute-capability allowlist for head_dim > 192 in the FA2 filter, replacing it with the simpler and correct head_dim_qk > 256 or head_dim_qk % 8 != 0 gate; log message updated to match. |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend called] --> B{device_compute_capability < sm80?}
    B -- Yes --> C[use_flash_attention_2 = False]
    B -- No --> D{use_flash_attention_2 AND FA2 installed?}
    D -- No --> G[Skip FA2 head_dim check]
    D -- Yes --> E{head_dim_qk > 256\nOR head_dim_qk % 8 != 0?}
    E -- Yes --> F[use_flash_attention_2 = False\nlog debug message]
    E -- No --> H[FA2 remains enabled]
    C --> I[Continue backend selection]
    F --> I
    G --> I
    H --> I

    style C fill:#f88,stroke:#c00
    style F fill:#f88,stroke:#c00
    style H fill:#8f8,stroke:#090
```

Reviews (3). Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread on transformer_engine/pytorch/attention/dot_product_attention/utils.py (outdated)
cyanguwa (Collaborator) commented Apr 6, 2026

/te-ci L0

ptrendx added 2 commits April 21, 2026 13:00
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
ptrendx (Member) commented Apr 21, 2026

/te-ci pytorch
